import argparse
import math
import os
import random
from time import time

import matlab
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from numpy import linalg as LA
from scipy.io import savemat


from HiQLip.main import HiQLipsolver
from utils.mnist import NeuralNetToy, NeuralNet2, NeuralNet2_128, NeuralNet2_256, NeuralNet3, \
    NeuralNet4, NeuralNet7, NeuralNet8, NeuralNetToy2, NeuralNet5
from utils.naiveNorms import NaiveNorms
from utils.solver import GL_Solver

# Model training globals
# Everything in this section runs at import time, including the MNIST
# download triggered by ``download=True`` below.
# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Passed as the first constructor argument of every network built in
# models_from_parser, and used as the sample width in the "sampling" method.
input_size = 28 * 28
# Passed as the second constructor argument of every network.
num_classes = 10
# Batch size shared by the train and test DataLoaders below.
batch_size = 100
# Learning rate given to the Adam optimizer in train_model.
learning_rate = 0.001

# Load the MNIST training split from ./data, downloading it if necessary.
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

# Test split; no ``download`` flag is passed here.
test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transforms.ToTensor())

# Training batches are reshuffled on each pass over the loader.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

# Test batches keep the dataset order.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False)

# Defaults for LinfPGDAttack: perturbation radius, number of PGD steps and
# per-step size (bound as default arguments when the class is defined).
epsilon = 0.1
k = 40
alpha = 0.01


class LinfPGDAttack(object):
    """Iterative signed-gradient attack kept inside an L-infinity ball.

    Attributes:
        model: callable mapping an input batch to logits.
        epsilon: radius of the allowed perturbation around the clean input.
        steps: number of gradient-ascent iterations.
        alpha: step size applied to the sign of the gradient.
    """

    def __init__(self, model, epsilon=epsilon, steps=k, alpha=alpha):
        self.model, self.epsilon = model, epsilon
        self.steps, self.alpha = steps, alpha

    def perturb(self, x_natural, y):
        """Return an adversarial version of ``x_natural`` for labels ``y``.

        The search starts from a uniformly random point within ``epsilon``
        of the clean input. Each iteration ascends the cross-entropy loss
        along the gradient sign, projects back into the epsilon-ball and
        clamps the result to ``[0, 1]``.
        """
        adv = x_natural.detach()
        start_noise = torch.zeros_like(adv).uniform_(-self.epsilon, self.epsilon)
        adv = adv + start_noise

        for _ in range(self.steps):
            adv.requires_grad_()
            # Gradients are needed here even if the caller disabled them.
            with torch.enable_grad():
                loss = F.cross_entropy(self.model(adv), y)
            (grad,) = torch.autograd.grad(loss, [adv])

            stepped = adv.detach() + self.alpha * torch.sign(grad.detach())
            # Project onto the epsilon-ball around the clean input.
            floor = torch.max(stepped, x_natural - self.epsilon)
            projected = torch.min(floor, x_natural + self.epsilon)
            adv = torch.clamp(projected, 0, 1)

        return adv


def models_from_parser(args):
    """Build the network selected by ``args.model`` and derive its file paths.

    Creates the ``models`` and ``mats`` directories if they are missing and
    prints a one-line summary of the run configuration.

    Args:
        args: parsed command-line namespace; reads ``model``, ``method``,
            ``l2`` and ``adv``.

    Returns:
        A tuple ``(model_path, mat_path, model)``: the checkpoint path under
        ``models/``, the ``.mat`` weight-file path under ``mats/`` and the
        freshly constructed (untrained) network on ``device``. When
        ``args.adv`` is set, both paths carry an ``_adv`` suffix.
    """
    os.makedirs('models', exist_ok=True)
    os.makedirs('mats', exist_ok=True)

    # model name -> (checkpoint stem, weight-file stem, network class)
    registry = {
        "net2_8": ("mnist_model_toy2", "mnist_weight_toy2", NeuralNetToy2),
        "net2": ("mnist_model2", "mnist_weight_model2", NeuralNet2),
        "net2_128": ("mnist_model2_128", "mnist_weight_model2_128", NeuralNet2_128),
        "net2_256": ("mnist_model2_256", "mnist_weight_model2_256", NeuralNet2_256),
        "net3": ("mnist_model3", "mnist_weight_model3", NeuralNet3),
        "net4": ("mnist_model4", "mnist_weight_model4", NeuralNet4),
        "net5": ("mnist_model5", "mnist_weight_model5", NeuralNet5),
        "net7": ("mnist_model7", "mnist_weight_model7", NeuralNet7),
        "net8": ("mnist_model8", "mnist_weight_model8", NeuralNet8),
    }
    # Any name without an entry (e.g. "net2_16") falls back to the toy net.
    fallback = ("mnist_model_toy", "mnist_weight_toy", NeuralNetToy)
    model_path, mat_path, model_cls = registry.get(args.model, fallback)

    # Only the selected network is instantiated and moved to the device.
    model = model_cls(input_size, num_classes).to(device)

    suffix = "_adv" if args.adv else ""
    model_path = model_path + suffix
    mat_path = mat_path + suffix + ".mat"

    print(f"Model: {args.model}, Method: {args.method}, L2: {args.l2}, Adv: {args.adv}")
    return "models/" + model_path, "mats/" + mat_path, model


def train_model(model_path, mat_path, model, adversary, adv):
    """Train ``model``, save its checkpoint and export its linear weights.

    Args:
        model_path: destination of the ``state_dict`` checkpoint.
        mat_path: destination of the ``.mat`` file holding the weights.
        model: network exposing ``train(criterion, optimizer)``,
            ``adv_train(criterion, optimizer, adversary)``, ``evaluate()``
            and ``adv_evaluate(adversary)``.
        adversary: attack used for adversarial training and evaluation; a
            default ``LinfPGDAttack`` on ``model`` is built when ``None``.
        adv: train adversarially when truthy, normally otherwise.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Honour the caller's adversary instead of silently replacing it.
    if adversary is None:
        adversary = LinfPGDAttack(model=model)

    # Exactly one training pass: adversarial or standard. (A second,
    # unconditional standard pass used to follow this branch.)
    if adv:
        model.adv_train(criterion, optimizer, adversary)
    else:
        model.train(criterion, optimizer)
    torch.save(model.state_dict(), model_path)

    model.evaluate()
    model.adv_evaluate(adversary)

    # Exact type match is kept on purpose: subclasses of nn.Linear are skipped.
    weights = [layer.weight.cpu().detach().numpy()
               for layer in model.modules()
               if type(layer) is nn.Linear]

    # Fill a pre-sized object array so the result is always a 1-D container of
    # matrices, independent of NumPy's shape inference on the nested list.
    cells = np.empty(len(weights), dtype=object)
    for idx, weight in enumerate(weights):
        cells[idx] = weight
    savemat(mat_path, {'weights': cells})


def OthersMethod(args):
    """Estimate per-class norms of the selected network with a baseline method.

    The network is trained first when its checkpoint is missing. The method
    is chosen by ``args.method``:

    * ``product``  - product of per-layer operator norms.
    * ``brute``    - ``NaiveNorms.BFNorms`` (only for ``net2_8``/``net2_16``).
    * ``sampling`` - largest input-gradient norm over random samples.
    * ``sdp_dual`` - MATLAB ``GeoLIP`` solver called with its last flag True.
    * ``sdp_py``   - ``GL_Solver.sdp_norm``.
    * ``sdp``      - redirected to ``sdp_dual`` when the network has more
      than two linear layers; otherwise no solver runs (the primal MATLAB
      path is disabled).

    ``args.l2`` selects order 2 instead of order 1 for the NumPy/torch norms
    and ``'2'`` instead of ``'inf'`` for the MATLAB solver.

    Args:
        args: parsed command-line namespace; reads ``model``, ``method``,
            ``l2`` and ``adv``. ``args.method`` is rewritten in the ``sdp``
            redirect case.
    """
    model_path, mat_path, model = models_from_parser(args)
    adversary = LinfPGDAttack(model=model)

    if not os.path.exists(model_path):
        train_model(model_path, mat_path, model, adversary, args.adv)

    # map_location lets a checkpoint saved on another device load here.
    model.load_state_dict(torch.load(model_path, map_location=device))
    weights = [layer.weight.cpu().detach().numpy()
               for layer in model.modules()
               if type(layer) is nn.Linear]
    classes = weights[-1].shape[0]

    if len(weights) > 2 and args.method == "sdp":
        print("We will use the dual program to estimate the FGL")
        # Previously assigned to a misspelled attribute, so no solver ran.
        args.method = "sdp_dual"
    start_time = time()

    norm_order = 2 if args.l2 else 1

    if args.method == "product":
        jcb_norm = 1
        for weight in weights[:-1]:
            jcb_norm *= LA.norm(weight.transpose(), norm_order)
        norms = [LA.norm(weights[-1][i, :].transpose(), norm_order) * jcb_norm
                 for i in range(classes)]
        print("Matrix Product Norms are: ", norms)

    if args.method == "brute" and args.model in ("net2_8", "net2_16"):
        vec = NaiveNorms(weights[0], weights[1]).BFNorms(norm_order)
        print("Brute Force Norms are: ", vec)

    if args.method == "sampling":
        lbs = []
        for i in range(classes):
            # Fresh uniform samples in [-5, 5) per coordinate for every class.
            noise = torch.rand(200000, input_size) - torch.ones(input_size) / 2
            center = torch.zeros(input_size)
            x = (10 * noise + center).to(device)
            x.requires_grad = True
            x.retain_grad()
            model(x)[:, i].sum().backward()
            norms = torch.norm(x.grad, p=norm_order, dim=1)
            lbs.append(torch.max(norms).item())
        print("Sampling Lower Bounds are: ", lbs)

    if args.method == "sdp_dual":
        # The engine lives in a submodule that ``import matlab`` alone does
        # not load; import it explicitly where it is needed.
        import matlab.engine
        eng = matlab.engine.start_matlab()
        try:
            eng.addpath(r'matlab_solver')
            lcs = eng.GeoLIP(mat_path, '2' if args.l2 else 'inf', True)
        finally:
            # Always shut the MATLAB process down, even if the solver fails.
            eng.quit()
        print("SDP Norms are: ", lcs)

    if args.method == "sdp_py":
        gs = GL_Solver(weights=weights, dual=True, approx_hidden=False, approx_input=False)
        result = gs.sdp_norm(parallel=False)
        print("CVXPY norms are:", result)
    print(f'Total time: {float(time() - start_time):.5} seconds')

def multiHMC(args):
    """Run ``HiQLipsolver`` on the graph built from the first layer's weights.

    The first linear layer is scaled row-wise by the output weights of one
    class (index 8) and embedded as the off-diagonal blocks of a symmetric
    adjacency matrix, which is then passed to ``HiQLipsolver``. The network
    is trained first when its checkpoint is missing.

    NOTE(review): the row scaling requires the final layer's input width to
    equal the first layer's output width, i.e. this presumably targets
    two-layer networks - confirm against callers.

    Args:
        args: parsed command-line namespace; reads ``model``, ``method``,
            ``l2`` and ``adv``.

    Returns:
        ``2 * cut - sum(adj)`` where ``cut`` is the first value returned by
        ``HiQLipsolver``.
    """
    model_path, mat_path, model = models_from_parser(args)
    adversary = LinfPGDAttack(model=model)

    if not os.path.exists(model_path):
        train_model(model_path, mat_path, model, adversary, args.adv)

    # map_location lets a checkpoint saved on another device load here.
    model.load_state_dict(torch.load(model_path, map_location=device))
    weights = [layer.weight.cpu().detach().numpy()
               for layer in model.modules()
               if type(layer) is nn.Linear]

    # Row of the final layer used for the scaling (fixed class index).
    target_class = 8
    # Sizes come from the weights themselves instead of a hard-coded 784.
    hidden_dim = weights[1].shape[1]
    input_dim = weights[0].shape[1]

    # Scale row i of the first layer by final_weight[i]. float64 is forced so
    # the values equal the former ``np.dot`` with a float64 diagonal matrix.
    final_weight = np.asarray(weights[-1][target_class, :hidden_dim], dtype=np.float64)
    scaled = np.asarray(weights[0], dtype=np.float64) * final_weight[:, None]

    # Symmetric bipartite adjacency: inputs first, hidden units after them.
    n = input_dim + hidden_dim
    adj = np.zeros((n, n))
    adj[input_dim:, :input_dim] = scaled
    adj[:input_dim, input_dim:] = scaled.T

    repeats = 1
    hmcmc = np.zeros(repeats)
    for i in range(repeats):
        time_begin = time()
        hmcmaxcut, _ = HiQLipsolver(adj)
        hmcmc[i] = 2 * hmcmaxcut - np.sum(adj)
        time_end = time() - time_begin
        print(i, "times, HMC is:", hmcmc[i], "HMC time is:", time_end,
              math.sqrt(hmcmaxcut))
    return hmcmc[0]

def _hiqlip_layer_norm(layer_weights):
    """Return ``2 * HiQLipsolver(adj)[0] - np.sum(adj)`` for one weight matrix.

    ``adj`` is the symmetric matrix holding ``layer_weights`` (rows = outputs,
    columns = inputs) and its transpose as off-diagonal blocks.
    """
    out_dim, in_dim = layer_weights.shape
    n = in_dim + out_dim
    adj = np.zeros((n, n))
    adj[in_dim:, :in_dim] = layer_weights
    adj[:in_dim, in_dim:] = layer_weights.T
    return 2 * HiQLipsolver(adj)[0] - np.sum(adj)


def mutilayers(args):
    """Estimate a product-form norm bound for a multi-layer network.

    The last two linear layers are merged by scaling the rows of the
    second-to-last weight matrix with the final layer's weights for one class
    (index 8). The ``HiQLipsolver``-based value of that merged matrix is
    multiplied by the value of every earlier layer and the product is divided
    by ``2 ** (num_layers - 2)``.

    The network is trained first when its checkpoint is missing, as the
    sibling entry points do.

    Args:
        args: parsed command-line namespace; reads ``model``, ``method``,
            ``l2`` and ``adv``.

    Returns:
        The final product value that is also printed.
    """
    model_path, mat_path, model = models_from_parser(args)
    if not os.path.exists(model_path):
        train_model(model_path, mat_path, model, LinfPGDAttack(model=model), args.adv)

    # map_location lets a checkpoint saved on another device load here.
    model.load_state_dict(torch.load(model_path, map_location=device))
    weights = [layer.weight.cpu().detach().numpy()
               for layer in model.modules()
               if type(layer) is nn.Linear]

    # Row of the final layer used for the scaling (fixed class index).
    target_class = 8
    # float64 keeps the values equal to the former product with a float64
    # diagonal matrix; shapes come from the weights, not a hard-coded 64x64.
    final_weight = np.asarray(weights[-1][target_class, :], dtype=np.float64)
    merged_last = np.asarray(weights[-2], dtype=np.float64) * final_weight[:, None]

    norm_result = _hiqlip_layer_norm(merged_last)
    print("Last Matrix Product Norms are: ", norm_result)

    # Remaining factors: every layer BEFORE the merged pair. The former index
    # ``weights[-i - 1]`` re-used the final layer and never reached layer 0.
    for layer_weights in weights[:-2]:
        norm_result *= _hiqlip_layer_norm(layer_weights)
    norm_result /= pow(2, len(weights) - 2)

    # Report the model actually evaluated rather than a script-level global.
    print("The net of ", args.model, "Norms are:  ", norm_result)
    return norm_result


if __name__ == '__main__':
    # Seed every RNG in use: torch drives weight initialisation, the PGD
    # start noise and the random samples, so it must be seeded as well.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)

    # Defaults used when --model / --method are omitted or given bare.
    network = 'net2_8'
    methodname = "sdp_dual"

    parser = argparse.ArgumentParser()
    parser.add_argument("--model",
                        nargs='?',
                        const=network,
                        default=network,
                        # Every name models_from_parser can resolve.
                        choices=['net2_8', 'net2_16', 'net2', 'net2_128', 'net2_256',
                                 'net3', 'net4', 'net5', 'net7', 'net8'],
                        help="which model to use")
    parser.add_argument("--method",
                        nargs='?',
                        const=methodname,
                        default=methodname,
                        choices=['brute', 'product', 'sdp', 'sdp_dual', 'sdp_py', 'sampling'],
                        help="which other method to use")
    parser.add_argument('--l2',
                        action='store_true',
                        help="estimate l_2 FGL or l_inf")
    parser.add_argument('--adv',
                        action='store_true',
                        default=False,
                        help="adversarial training a network")

    args = parser.parse_args()
    # Example: python mnist_eval.py --model net2 --method sdp_py

    # Alternative entry points: OthersMethod(args) or multiHMC(args).
    mutilayers(args)





